import carla
import numpy as np

from srunner.scenariomanager.carla_data_provider import CarlaDataProvider

from utils.geometry import check_2d_visibility
from utils.shapely_geometry import check_2d_visibility_shapely

RED = carla.TrafficLightState.Red
YELLOW = carla.TrafficLightState.Yellow
GREEN = carla.TrafficLightState.Green
# Human-readable name for every carla.TrafficLightState. Off and Unknown are
# listed so that looking up a light in either state does not raise KeyError.
TRAFFIC_LIGHT_STATES = {
    carla.TrafficLightState.Red: "red",
    carla.TrafficLightState.Yellow: "yellow",
    carla.TrafficLightState.Green: "green",
    carla.TrafficLightState.Off: "off",
    carla.TrafficLightState.Unknown: "unknown",
}
# Human-readable lane-marking colors; Standard is described as "white".
LM_COLORS = {
    carla.LaneMarkingColor.White: "white",
    carla.LaneMarkingColor.Blue: "blue",
    carla.LaneMarkingColor.Green: "green",
    carla.LaneMarkingColor.Red: "red",
    carla.LaneMarkingColor.Yellow: "yellow",
    carla.LaneMarkingColor.Other: "other",
    carla.LaneMarkingColor.Standard: "white",
}
# Human-readable lane-marking types. Other, BottsDots, Grass and Curb are
# listed so that describing such a marking does not raise KeyError.
# NOTE(review): NONE (no painted marking) is described as "solid" —
# presumably so an unmarked edge reads as a boundary; confirm intended.
LM_TYPES = {
    carla.LaneMarkingType.NONE: "solid",
    carla.LaneMarkingType.Broken: "dashed",
    carla.LaneMarkingType.Solid: "solid",
    carla.LaneMarkingType.SolidSolid: "double solid",
    carla.LaneMarkingType.SolidBroken: "solid-dashed",
    carla.LaneMarkingType.BrokenSolid: "dashed-solid",
    carla.LaneMarkingType.BrokenBroken: "double dashed",
    carla.LaneMarkingType.BottsDots: "botts' dots",
    carla.LaneMarkingType.Grass: "grass",
    carla.LaneMarkingType.Curb: "curb",
    carla.LaneMarkingType.Other: "other",
}

class PartialObservableCaptioner:
    """Describe, in natural language, what the ego vehicle can observe.

    The description covers the ego vehicle's own state, its traffic light,
    its lane, and the other vehicles that pass a 2D visibility check within
    ``perception_range`` of the ego.
    """

    # Maximum number of 1 m successor steps taken when searching for the next
    # junction ahead. Bounding the walk prevents an endless loop on a closed
    # road that contains no junction (the successor chain would cycle).
    _MAX_JUNCTION_SEARCH_STEPS = 200
    # A junction is only mentioned when its straight-line distance to the ego
    # vehicle is below this value (metres).
    _JUNCTION_REPORT_DISTANCE = 30

    def __init__(self,
                 ego_vehicle,
                 perception_range=50,
                 ):
        """Create a captioner for ``ego_vehicle``.

        Args:
            ego_vehicle: the CARLA vehicle actor being driven.
            perception_range: distance (metres) within which level bounding
                boxes are considered by the visibility check.
        """
        self.ego_vehicle = ego_vehicle
        self.world = CarlaDataProvider.get_world()
        self.map = self.world.get_map()
        self.perception_range = perception_range
        # vehicle id -> last traffic-light state string seen while that
        # vehicle was at a light (None until it has been at one). The value is
        # deliberately kept after the vehicle leaves the light so a red-light
        # run can still be detected once it is inside the intersection.
        self.traffic_light_status = {}

    # ------------------------------------------------------------------ #
    # Small geometry / formatting helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _yaw_difference(yaw, reference_yaw):
        """Return ``yaw - reference_yaw`` in degrees, wrapped into [-180, 180)."""
        return (yaw - reference_yaw + 180.0) % 360.0 - 180.0

    @classmethod
    def _is_aligned(cls, yaw, reference_yaw):
        """True when ``yaw`` points within 90 degrees of ``reference_yaw``."""
        return abs(cls._yaw_difference(yaw, reference_yaw)) < 90

    @classmethod
    def _crossing_direction(cls, other_yaw, ego_yaw):
        """Classify a road heading relative to the ego road heading.

        Returns "right" when ``other_yaw`` is 45-135 degrees to the positive
        side of ``ego_yaw``, "left" when 45-135 degrees to the negative side,
        and None otherwise. The difference is wrap-normalized so headings on
        either side of the +/-180 degree seam compare correctly.
        """
        diff = cls._yaw_difference(other_yaw, ego_yaw)
        if 45 < diff < 135:
            return "right"
        if -135 < diff < -45:
            return "left"
        return None

    @classmethod
    def _find_next_junction(cls, wp):
        """Follow the first successor of ``wp`` in 1 m steps.

        Returns the first waypoint that is inside a junction, or None when the
        road ends (no successor) or no junction is found within
        ``_MAX_JUNCTION_SEARCH_STEPS`` steps.
        """
        current = wp
        for _ in range(cls._MAX_JUNCTION_SEARCH_STEPS):
            successors = current.next(1)
            if not successors:
                return None
            current = successors[0]
            if current.is_intersection:
                return current
        return None

    @staticmethod
    def _marking_description(marking):
        """Return "<color>, <type>" for a lane marking.

        Unmapped enum values fall back to their lowercased enum name instead
        of raising KeyError.
        """
        color = LM_COLORS.get(marking.color, str(marking.color).lower())
        marking_type = LM_TYPES.get(marking.type, str(marking.type).lower())
        return f"{color}, {marking_type}"

    # ------------------------------------------------------------------ #
    # Public description API
    # ------------------------------------------------------------------ #
    def get_description(self):
        """Return the full scene description as a single string."""
        desc = ""
        # Describe the ego vehicle state
        desc += self.describe_ego_vehicle()
        # Describe the traffic light status
        desc += self.describe_traffic_light(self.ego_vehicle)
        # Describe the ego vehicle road geometry
        desc += self.describe_road_geometry()
        # Describe the current lane information
        desc += self.describe_lane()
        # Describe the other vehicles according to the partial observability
        desc += self.describe_other_vehicles()
        return desc

    def describe_ego_vehicle(self):
        """Describe the ego vehicle's id, blueprint, speed and speed limit."""
        desc = f"You are driving the Vehicle {self.ego_vehicle.id}, "
        desc += f"and it is a {self.ego_vehicle.type_id}. "
        desc += f"Your current speed is: {self.ego_vehicle.get_velocity().length() :.2f} m/s, "
        # get_speed_limit() is divided by 3.6 (km/h -> m/s).
        # NOTE(review): a 3 m/s margin is added to the reported limit —
        # confirm this allowance is intentional.
        desc += f"and the speed limit is: {self.ego_vehicle.get_speed_limit() / 3.6 + 3 :.2f} m/s. "
        return desc

    def describe_traffic_light(self, vehicle):
        """Describe the traffic-light situation of ``vehicle``.

        Updates ``self.traffic_light_status`` for ``vehicle`` whenever it is at
        a traffic light. Reports a vehicle that is moving inside an
        intersection while its last seen light was red; for the ego vehicle,
        otherwise reports its last seen light state (if any).

        NOTE(review): the stored state is never cleared, so the ego keeps
        reporting its last light after it has passed it — confirm intended.
        """
        desc = ""
        status = self.traffic_light_status.setdefault(vehicle.id, None)
        if vehicle.is_at_traffic_light():
            state = vehicle.get_traffic_light_state()
            # .get avoids KeyError for states such as Off / Unknown.
            status = TRAFFIC_LIGHT_STATES.get(state, str(state).lower())
            self.traffic_light_status[vehicle.id] = status

        is_ego = vehicle.id == self.ego_vehicle.id
        vehicle_at_intersection = self.map.get_waypoint(vehicle.get_location()).is_intersection
        if (status == TRAFFIC_LIGHT_STATES[RED]
            and vehicle_at_intersection
            and vehicle.get_velocity().length() > 0.1
        ):
            # A vehicle whose light was red is moving inside the intersection
            if is_ego:
                desc += "Your traffic light is red but you are at the intersection now. "
            else:
                desc += f"Vehicle {vehicle.id} from intersecting road is proceeding through the intersection now. "
        elif is_ego and status is not None:
            # Report the ego vehicle's most recently seen traffic light state
            desc += f"The traffic light for you is currently: {status}. "
        return desc

    def describe_road_geometry(self):
        """Describe the road geometry (not implemented yet; returns "")."""
        # TODO: something like "at a four-way intersection" "on a ramp onto a 4-way highway"
        desc = ""
        return desc

    def describe_lane(self):
        """Describe the ego lane, its markings and any junction close ahead."""
        location = self.ego_vehicle.get_location()
        desc = ""
        wp = self.map.get_waypoint(location)
        # Speed limits above 60 (km/h) are treated as highways.
        on_highway = self.ego_vehicle.get_speed_limit() > 60

        if wp.is_intersection:
            if on_highway:
                desc += "You are on a highway interchange junction. "
            else:
                desc += "You are at an intersection / junction. "
            return desc
        if on_highway:
            desc += "You are on a highway. "

        ego_yaw = self.ego_vehicle.get_transform().rotation.yaw
        lane_alignment = self._is_aligned(ego_yaw, wp.transform.rotation.yaw)
        opposite_lane = "" if lane_alignment else "opposite "
        desc += f"You are in the {opposite_lane}lane: {wp.lane_id} on road {wp.road_id}, "
        desc += f"and it is a {wp.lane_type} type lane. "

        left_marking = self._marking_description(wp.left_lane_marking)
        right_marking = self._marking_description(wp.right_lane_marking)
        if lane_alignment:
            desc += f"On your left is {left_marking} line, "
            desc += f"On your right is {right_marking} line, "
        else:
            # Driving against the lane direction swaps the lane's left/right.
            desc += f"On your right is {left_marking} line, "
            desc += f"On your left is {right_marking} line, "

        junction_wp = self._find_next_junction(wp)
        if junction_wp is not None:
            distance_to_next_intersection = np.round(
                location.distance(junction_wp.transform.location), 2)
            if distance_to_next_intersection < self._JUNCTION_REPORT_DISTANCE:
                desc += f"The next intersection / interchange junction is {distance_to_next_intersection} m ahead. "
        return desc

    def describe_other_vehicles(self) -> str:
        """Describe every other vehicle whose bounding box passes the visibility check."""
        obj_list = self.get_obj_list()
        ego_bbx = self.ego_vehicle.bounding_box
        ego_bbx.location = self.ego_vehicle.get_location()
        obj_visibility = check_2d_visibility_shapely(ego_bbx, obj_list, self.perception_range, threshold=0.4)
        desc = "Around you, there are other vehicles: \n"
        identity = carla.Transform()
        num_vehicle = 0
        for vehicle in self.world.get_actors().filter('vehicle.*'):
            if vehicle.id == self.ego_vehicle.id:
                continue
            vehicle_location = vehicle.get_location()
            # A vehicle inherits the visibility of the first level bounding
            # box that contains its location; no box means not visible.
            visible = False
            for i, obj in enumerate(obj_list):
                if obj.contains(vehicle_location, identity):
                    visible = obj_visibility[i]
                    break
            if visible:
                num_vehicle += 1
                desc += f"({num_vehicle})"
                desc += self.describe_other_vehicle(vehicle)
                desc += "\n"
        return desc

    def get_obj_list(self):
        """Return level bounding boxes of road users near (but not containing) the ego.

        Each box's ``extent.z`` is enlarged by 1 before filtering. Boxes
        farther than ``perception_range`` from the ego, or containing the ego
        location (the ego's own box), are excluded.
        """
        labels = (
            carla.CityObjectLabel.Pedestrians,
            carla.CityObjectLabel.Car,
            carla.CityObjectLabel.Bus,
            carla.CityObjectLabel.Motorcycle,
            carla.CityObjectLabel.Bicycle,
            carla.CityObjectLabel.Truck,
        )
        ego_location = self.ego_vehicle.get_location()
        identity = carla.Transform()
        obj_list = []
        for label in labels:
            for bb in self.world.get_level_bbs(label):
                bb.extent.z += 1
                if (ego_location.distance(bb.location) < self.perception_range
                        and not bb.contains(ego_location, identity)):
                    obj_list.append(bb)
        return obj_list

    def describe_transform(self, transform):
        """Describe ``transform``'s position relative to the ego (ahead/behind, left/right)."""
        ego_transform = self.ego_vehicle.get_transform()
        transform_in_ego_coord = get_transform_in_ego_coord(ego_transform, transform)
        relative_x = transform_in_ego_coord[0][3]
        relative_y = transform_in_ego_coord[1][3]
        desc = "It is {} meters {} you and {} meters to your {}. ".format(
                np.round(abs(relative_x), 2), "ahead of" if relative_x > 0 else "behind",
                np.round(abs(relative_y), 2), "right" if relative_y > 0 else "left")
        return desc

    def describe_relative_velocity(self, vehicle):
        """Say whether ``vehicle`` is heading toward or away from the ego.

        Uses only the other vehicle's own velocity projected onto the vector
        from the ego to it; it must be moving faster than 0.5 m/s.
        """
        relative_location = vehicle.get_location() - self.ego_vehicle.get_location()
        other_vehicle_velocity = vehicle.get_velocity()
        approach = relative_location.dot(other_vehicle_velocity)
        moving = other_vehicle_velocity.length() > 0.5
        distance = relative_location.length()
        if approach < 0 and moving and distance < 10:
            return "It is moving closer to you. "
        if approach > 0 and moving and distance > 3:
            return "It is moving away from you. "
        return ""

    def describe_other_vehicle(self, vehicle):
        """Describe one visible vehicle: type, speed, road/lane relation, light, position."""
        desc = ""
        ego_location = self.ego_vehicle.get_location()
        ego_wp = self.map.get_waypoint(ego_location)
        ego_yaw = self.ego_vehicle.get_transform().rotation.yaw
        ego_wp_yaw = ego_wp.transform.rotation.yaw
        # +1 when the ego drives along its lane direction, -1 when against it.
        ego_lane_alignment = 1 if self._is_aligned(ego_yaw, ego_wp_yaw) else -1

        desc += f"Vehicle {vehicle.id} "
        desc += f"is a {vehicle.type_id}, "
        other_vehicle_speed = vehicle.get_velocity().length()
        if other_vehicle_speed < 0.1:
            desc += "and is stationary. "
        else:
            desc += f"traveling at speed: {other_vehicle_speed:.2f} m/s. "

        # Road description
        other_vehicle_wp = self.map.get_waypoint(vehicle.get_location())
        other_vehicle_wp_yaw = other_vehicle_wp.transform.rotation.yaw
        ego_vehicle_at_intersection = ego_wp.is_intersection
        other_vehicle_at_intersection = other_vehicle_wp.is_intersection
        if not ego_vehicle_at_intersection and not other_vehicle_at_intersection:
            ego_road_id = ego_wp.road_id
            other_vehicle_road_id = other_vehicle_wp.road_id
            if ego_road_id == other_vehicle_road_id:
                # On the same road, describe the relative lane
                desc += "It is on the same road as you. "
                ego_vehicle_lane_id = ego_wp.lane_id
                other_vehicle_lane_id = other_vehicle_wp.lane_id
                desc += f"It is in the lane: {other_vehicle_lane_id}, "
                if other_vehicle_lane_id == ego_vehicle_lane_id:
                    desc += "and it is in the same lane as you. "
                else:
                    if other_vehicle_lane_id * ego_vehicle_lane_id > 0:
                        # Same sign of lane id: same travel direction
                        lane_difference = abs(other_vehicle_lane_id) - abs(ego_vehicle_lane_id)
                        lane_direction = "same"
                        relative_lane_position = "right" if lane_difference * ego_lane_alignment > 0 else "left"
                        lane_difference = abs(lane_difference)
                    else:
                        # Opposite sign: opposite travel direction. Lane id 0
                        # is skipped, so the count is reduced by one.
                        lane_difference = other_vehicle_lane_id - ego_vehicle_lane_id
                        lane_direction = "opposite"
                        relative_lane_position = "right" if lane_difference * ego_lane_alignment * ego_vehicle_lane_id > 0 else "left"
                        lane_difference = abs(lane_difference) - 1
                    desc += f"and it is {abs(lane_difference)} lanes to your {relative_lane_position} in the {lane_direction} direction. "
            else:
                desc += f"It is on a different road {other_vehicle_road_id}. "
                crossing = self._crossing_direction(other_vehicle_wp_yaw, ego_wp_yaw)
                if crossing is not None:
                    desc += f"It is an intersecting road going {crossing}. "
        elif ego_vehicle_at_intersection and not other_vehicle_at_intersection:
            desc += self.describe_relative_velocity(vehicle)
            desc += f"It is on road {other_vehicle_wp.road_id}"
            crossing = self._crossing_direction(other_vehicle_wp_yaw, ego_wp_yaw)
            if crossing is None:
                desc += ". "
            else:
                desc += f", which is an intersecting road going {crossing}. "
        else:
            if vehicle.get_speed_limit() > 60:
                desc += "It is on a highway interchange junction. "
            else:
                desc += "It is at an intersection / junction. "
            desc += self.describe_relative_velocity(vehicle)

        # Describe the traffic light status
        desc += self.describe_traffic_light(vehicle)
        # Describe relative position
        desc += f"{self.describe_transform(vehicle.get_transform())} "
        return desc

def get_transform_in_ego_coord(ego_transform, target_transform):
    """Express ``target_transform`` in the ego vehicle's coordinate frame.

    Multiplies the ego transform's inverse matrix by the target transform's
    matrix. The result's translation column (``[0][3]``, ``[1][3]``,
    ``[2][3]``) is the target position relative to the ego.

    Args:
        ego_transform: object exposing ``get_inverse_matrix()`` (e.g. a
            ``carla.Transform``).
        target_transform: object exposing ``get_matrix()``.

    Returns:
        numpy.ndarray: the product matrix (4x4 for CARLA transforms).
    """
    ego_transform_inverse_matrix = ego_transform.get_inverse_matrix()
    target_transform_matrix = target_transform.get_matrix()
    return np.matmul(ego_transform_inverse_matrix, target_transform_matrix)
